{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.data as tud\n",
    "from torch.nn.parameter import Parameter\n",
    "\n",
    "from collections import Counter\n",
    "import numpy as np\n",
    "import random\n",
    "import math\n",
    "\n",
    "import pandas as pd\n",
    "import scipy\n",
    "import sklearn\n",
    "from sklearn.metrics.pairwise import cosine_similarity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 为了保证实验结果可以复现，我们经常会把各种random seed固定在某一个值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(53113)\n",
    "np.random.seed(53113)\n",
    "torch.manual_seed(53113)\n",
    "\n",
    "# 设定一些超参数\n",
    "    \n",
    "K = 100 # number of negative samples\n",
    "C = 3 # nearby words threshold\n",
    "NUM_EPOCHS = 2 # The number of epochs of training\n",
    "MAX_VOCAB_SIZE = 30000 # the vocabulary size\n",
    "BATCH_SIZE = 128 # the batch size\n",
    "LEARNING_RATE = 0.2 # the initial learning rate\n",
    "EMBEDDING_SIZE = 100\n",
    "\n",
    "LOG_FILE = \"word-embedding.log\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 从文本文件中读取所有的文字，通过这些文本创建一个vocabulary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "30000"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the whole training corpus and split it into whitespace-separated tokens.\n",
    "with open(\"./data/text8/text8.train.txt\", \"r\", encoding=\"utf-8\") as fin:\n",
    "    text = fin.read().split()\n",
    "\n",
    "# Keep the MAX_VOCAB_SIZE - 1 most frequent words; the remaining slot is reserved for <unk>.\n",
    "vocab = dict(Counter(text).most_common(MAX_VOCAB_SIZE - 1))\n",
    "# <unk> stands for every word outside the kept vocabulary, so its count is the number of leftover tokens.\n",
    "vocab[\"<unk>\"] = len(text) - sum(vocab.values())\n",
    "\n",
    "idx_to_word = list(vocab)  # index -> word (dict insertion order, so <unk> comes last)\n",
    "word_to_idx = {word: i for i, word in enumerate(idx_to_word)}  # word -> index\n",
    "\n",
    "word_counts = np.array(list(vocab.values()), dtype=np.float32)  # per-word counts, aligned with idx_to_word\n",
    "# Negative-sampling distribution: relative frequency raised to the 3/4 power, then renormalised to sum to 1.\n",
    "word_freqs = word_counts / np.sum(word_counts)\n",
    "word_freqs = word_freqs ** (3./4.)\n",
    "word_freqs = word_freqs / np.sum(word_freqs)\n",
    "\n",
    "VOCAB_SIZE = len(idx_to_word)\n",
    "VOCAB_SIZE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class WordEmbeddingDataset(tud.Dataset):\n",
    "    def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts):\n",
    "        ''' text: a list of words, all text from the training dataset\n",
    "            word_to_idx: the dictionary from word to idx\n",
    "            idx_to_word: idx to word mapping\n",
    "            word_freq: the frequency of each word\n",
    "            word_counts: the word counts\n",
    "        '''\n",
    "        super(WordEmbeddingDataset, self).__init__()\n",
    "        self.text_encoded = [word_to_idx.get(t, VOCAB_SIZE-1) for t in text]\n",
    "        self.text_encoded = torch.Tensor(self.text_encoded).long()\n",
    "        self.word_to_idx = word_to_idx\n",
    "        self.idx_to_word = idx_to_word\n",
    "        self.word_freqs = torch.Tensor(word_freqs)\n",
    "        self.word_counts = torch.Tensor(word_counts)\n",
    "    \n",
    "    def __len__(self):  #__len__ function需要返回整个数据集中有多少个item\n",
    "        ''' 返回整个数据集（所有单词）的长度\n",
    "        '''\n",
    "        return len(self.text_encoded)\n",
    "        \n",
    "    def __getitem__(self, idx): #__getitem__ 根据给定的index返回一个item\n",
    "        ''' 这个function返回以下数据用于训练\n",
    "            - 中心词\n",
    "            - 这个单词附近的(positive)单词\n",
    "            - 随机采样的K个单词作为negative sample\n",
    "        '''\n",
    "        center_word = self.text_encoded[idx]\n",
    "        pos_indices = list(range(idx-C, idx)) + list(range(idx+1, idx+C+1))\n",
    "        pos_indices = [i%len(self.text_encoded) for i in pos_indices]\n",
    "        pos_words = self.text_encoded[pos_indices] \n",
    "        neg_words = torch.multinomial(self.word_freqs, K * pos_words.shape[0], True)\n",
    "        \n",
    "        return center_word, pos_words, neg_words "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the tokenised corpus in the Dataset and serve it in shuffled mini-batches (loaded in the main process).\n",
    "dataset = WordEmbeddingDataset(\n",
    "    text,\n",
    "    word_to_idx,\n",
    "    idx_to_word,\n",
    "    word_freqs,\n",
    "    word_counts,\n",
    ")\n",
    "dataloader = tud.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    num_workers=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 定义PyTorch模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EmbeddingModel(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_size):\n",
    "        ''' 初始化输出和输出embedding\n",
    "        '''\n",
    "        super(EmbeddingModel, self).__init__()\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embed_size = embed_size\n",
    "        \n",
    "        initrange = 0.5 / self.embed_size\n",
    "        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)\n",
    "        self.out_embed.weight.data.uniform_(-initrange, initrange)\n",
    "        \n",
    "        \n",
    "        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)\n",
    "        self.in_embed.weight.data.uniform_(-initrange, initrange)\n",
    "        \n",
    "        \n",
    "    def forward(self, input_labels, pos_labels, neg_labels):\n",
    "        '''\n",
    "        input_labels: 中心词, [batch_size]\n",
    "        pos_labels: 中心词周围 context window 出现过的单词 [batch_size * (window_size * 2)]\n",
    "        neg_labelss: 中心词周围没有出现过的单词，从 negative sampling 得到 [batch_size, (window_size * 2 * K)]\n",
    "        \n",
    "        return: loss, [batch_size]\n",
    "        '''\n",
    "        \n",
    "        batch_size = input_labels.size(0)\n",
    "        \n",
    "        input_embedding = self.in_embed(input_labels) # B * embed_size\n",
    "        pos_embedding = self.out_embed(pos_labels) # B * (2*C) * embed_size\n",
    "        neg_embedding = self.out_embed(neg_labels) # B * (2*C * K) * embed_size\n",
    "      \n",
    "        log_pos = torch.bmm(pos_embedding, input_embedding.unsqueeze(2)).squeeze() # B * (2*C)\n",
    "        log_neg = torch.bmm(neg_embedding, -input_embedding.unsqueeze(2)).squeeze() # B * (2*C*K)\n",
    "\n",
    "        log_pos = F.logsigmoid(log_pos).sum(1)\n",
    "        log_neg = F.logsigmoid(log_neg).sum(1) # batch_size\n",
    "       \n",
    "        loss = log_pos + log_neg\n",
    "        \n",
    "        return -loss\n",
    "    \n",
    "    def input_embeddings(self):\n",
    "        return self.in_embed.weight.data.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the skip-gram model, sized by the vocabulary and embedding-dimension constants.\n",
    "model = EmbeddingModel(vocab_size=VOCAB_SIZE, embed_size=EMBEDDING_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(filename, embedding_weights):\n",
    "    ''' Score the embeddings against a word-similarity benchmark file.\n",
    "\n",
    "        filename: benchmark file; names ending in \".csv\" are read comma-separated,\n",
    "            anything else tab-separated. The first two columns are read as the word\n",
    "            pair and the third as the reference similarity score.\n",
    "        embedding_weights: 2-D array whose rows are indexed by the notebook-level word_to_idx\n",
    "\n",
    "        Pairs containing a word that is missing from word_to_idx are skipped.\n",
    "\n",
    "        return: the result of scipy.stats.spearmanr(reference scores, model cosine similarities)\n",
    "    '''\n",
    "    separator = \",\" if filename.endswith(\".csv\") else \"\\t\"\n",
    "    data = pd.read_csv(filename, sep=separator)\n",
    "    human_similarity = []\n",
    "    model_similarity = []\n",
    "    # Walk the rows positionally so the result does not depend on the DataFrame's index labels.\n",
    "    for word1, word2, score in data.iloc[:, :3].itertuples(index=False, name=None):\n",
    "        if word1 not in word_to_idx or word2 not in word_to_idx:\n",
    "            continue\n",
    "        word1_embed = embedding_weights[[word_to_idx[word1]]]\n",
    "        word2_embed = embedding_weights[[word_to_idx[word2]]]\n",
    "        # cosine_similarity returns a (1, 1) array; pick the scalar out explicitly because\n",
    "        # float() on an array with ndim > 0 is deprecated in NumPy.\n",
    "        similarity = sklearn.metrics.pairwise.cosine_similarity(word1_embed, word2_embed)[0, 0]\n",
    "        model_similarity.append(float(similarity))\n",
    "        human_similarity.append(float(score))\n",
    "\n",
    "    return scipy.stats.spearmanr(human_similarity, model_similarity)\n",
    "\n",
    "def find_nearest(word):\n",
    "    index = word_to_idx[word]\n",
    "    embedding = embedding_weights[index]\n",
    "    cos_dis = np.array([scipy.spatial.distance.cosine(e, embedding) for e in embedding_weights])\n",
    "    return [idx_to_word[i] for i in cos_dis.argsort()[:10]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, iter: 0, loss: 420.0462646484375\n",
      "epoch: 0, iteration: 0, simlex-999: SpearmanrResult(correlation=-0.006442402888187286, pvalue=0.8424763247082995), men: SpearmanrResult(correlation=0.015376097930182658, pvalue=0.43481460990875687), sim353: SpearmanrResult(correlation=-0.04156752024117713, pvalue=0.45940615976287436), nearest to monster: ['monster', 'researched', 'endeavor', 'waves', 'eyesight', 'idiot', 'bukhari', 'unusable', 'atrocities', 'indra']\n",
      "\n",
      "epoch: 0, iter: 100, loss: 290.9800109863281\n",
      "epoch: 0, iter: 200, loss: 206.2798309326172\n",
      "epoch: 0, iter: 300, loss: 167.51014709472656\n",
      "epoch: 0, iter: 400, loss: 143.33255004882812\n",
      "epoch: 0, iter: 500, loss: 128.27880859375\n"
     ]
    }
   ],
   "source": [
    "# Train with plain SGD; log the loss every 100 iterations and run the similarity\n",
    "# benchmarks every 2000 iterations. Embeddings and model state are saved after each epoch.\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)\n",
    "for e in range(NUM_EPOCHS):\n",
    "    for i, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):\n",
    "        # Embedding lookups need integer indices; .long() is a no-op when the batch is already int64.\n",
    "        input_labels = input_labels.long()\n",
    "        pos_labels = pos_labels.long()\n",
    "        neg_labels = neg_labels.long()\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss = model(input_labels, pos_labels, neg_labels).mean()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if i % 100 == 0:\n",
    "            loss_message = \"epoch: {}, iter: {}, loss: {}\".format(e, i, loss.item())\n",
    "            with open(LOG_FILE, \"a\") as fout:\n",
    "                fout.write(loss_message + \"\\n\")\n",
    "            print(loss_message)\n",
    "\n",
    "        if i % 2000 == 0:\n",
    "            embedding_weights = model.input_embeddings()\n",
    "            sim_simlex = evaluate(\"./data/simlex-999.txt\", embedding_weights)\n",
    "            sim_men = evaluate(\"./data/men.txt\", embedding_weights)\n",
    "            sim_353 = evaluate(\"./data/wordsim353.csv\", embedding_weights)\n",
    "            # Build the report once: find_nearest scans the whole vocabulary, and formatting\n",
    "            # the message separately for stdout and for the log file ran that scan twice.\n",
    "            eval_message = \"epoch: {}, iteration: {}, simlex-999: {}, men: {}, sim353: {}, nearest to monster: {}\\n\".format(\n",
    "                e, i, sim_simlex, sim_men, sim_353, find_nearest(\"monster\"))\n",
    "            print(eval_message)\n",
    "            with open(LOG_FILE, \"a\") as fout:\n",
    "                fout.write(eval_message)\n",
    "\n",
    "    # End of epoch: save the input embeddings and the model state.\n",
    "    embedding_weights = model.input_embeddings()\n",
    "    np.save(\"embedding-{}\".format(EMBEDDING_SIZE), embedding_weights)\n",
    "    torch.save(model.state_dict(), \"embedding-{}.th\".format(EMBEDDING_SIZE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
